{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp losses"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Losses\n",
    "\n",
 This contains">
    "> This module contains loss functions that are not available in fastai or PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from tsai.imports import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "## Available in Pytorch 1.9\n",
    "class HuberLoss(nn.Module):\n",
    "    \"\"\"Huber loss \n",
    "    \n",
    "    Creates a criterion that uses a squared term if the absolute\n",
    "    element-wise error falls below delta and a delta-scaled L1 term otherwise.\n",
    "    This loss combines advantages of both :class:`L1Loss` and :class:`MSELoss`; the\n",
    "    delta-scaled L1 region makes the loss less sensitive to outliers than :class:`MSELoss`,\n",
    "    while the L2 region provides smoothness over :class:`L1Loss` near 0. See\n",
    "    `Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`_ for more information.\n",
    "    This loss is equivalent to nn.SmoothL1Loss when delta == 1.\n",
    "    \"\"\"\n",
    "    def __init__(self, reduction='mean', delta=1.0):\n",
    "        assert reduction in ['mean', 'sum', 'none'], \"You must set reduction to 'mean', 'sum' or 'none'\"\n",
    "        self.reduction, self.delta = reduction, delta\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n",
    "        diff = input - target\n",
    "        abs_diff = torch.abs(diff)\n",
    "        mask = abs_diff < self.delta\n",
    "        loss = torch.cat([(.5*diff[mask]**2), self.delta * (abs_diff[~mask] - (.5 * self.delta))])\n",
    "        if self.reduction == 'mean':\n",
    "            return loss.mean()\n",
    "        elif self.reduction == 'sum':\n",
    "            return loss.sum()\n",
    "        else: \n",
    "            return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class LogCoshLoss(nn.Module):\n",
    "    def __init__(self, reduction='mean', delta=1.0):\n",
    "        assert reduction in ['mean', 'sum', 'none'], \"You must set reduction to 'mean', 'sum' or 'none'\"\n",
    "        self.reduction, self.delta = reduction, delta\n",
    "        super().__init__()\n",
    "        \n",
    "    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n",
    "        loss = torch.log(torch.cosh(input - target + 1e-12))\n",
    "        if self.reduction == 'mean':\n",
    "            return loss.mean()\n",
    "        elif self.reduction == 'sum':\n",
    "            return loss.sum()\n",
    "        else: \n",
    "            return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4650)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inp = torch.rand(8, 3, 10)\n",
    "targ = torch.randn(8, 3, 10)\n",
    "test_close(HuberLoss(delta=1)(inp, targ), nn.SmoothL1Loss()(inp, targ))\n",
    "LogCoshLoss()(inp, targ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src onerror=\"\n",
       "        this.nextElementSibling.focus();\n",
       "        this.dispatchEvent(new KeyboardEvent('keydown', {key:'s', keyCode: 83, metaKey: true}));\n",
       "        \" style=\"display:none\"><input style=\"width:0;height:0;border:0\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide\n",
    "from tsai.imports import create_scripts\n",
    "from tsai.export import get_nb_name\n",
    "nb_name = get_nb_name()\n",
    "create_scripts(nb_name);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
